{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/train-images-idx3-ubyte.gz\n",
      "Extracting data/train-labels-idx1-ubyte.gz\n",
      "Extracting data/t10k-images-idx3-ubyte.gz\n",
      "Extracting data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    " Weakly Supervised Net (Global Average Pooling) with MNIST\n",
    " @Sungjoon Choi (sungjoon.choi@cpslab.snu.ac.kr\n",
    "\"\"\"\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "%matplotlib inline  \n",
    "\n",
    "mnist = input_data.read_data_sets('data/', one_hot=True)\n",
    "trainimgs   = mnist.train.images\n",
    "trainlabels = mnist.train.labels\n",
    "testimgs    = mnist.test.images\n",
    "testlabels  = mnist.test.labels\n",
    "\n",
    "ntrain      = trainimgs.shape[0]\n",
    "ntest       = testimgs.shape[0]\n",
    "dim         = trainimgs.shape[1]\n",
    "nout        = trainlabels.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net Ready\n"
     ]
    }
   ],
   "source": [
    "# Make multilayer perceptron \n",
    "\n",
    "weights = {\n",
    "    'wc1' : tf.Variable(tf.random_normal([3, 3, 1, 128], stddev=0.1)), \n",
    "    'wc2' : tf.Variable(tf.random_normal([3, 3, 128, 256], stddev=0.1)), \n",
    "    'out' : tf.Variable(tf.random_normal([256, 10], stddev=0.1))\n",
    "}\n",
    "biases = {\n",
    "    'bc1' : tf.Variable(tf.random_normal([128], stddev=0.1)),\n",
    "    'bc2' : tf.Variable(tf.random_normal([256], stddev=0.1)),\n",
    "    'out' : tf.Variable(tf.random_normal([10], stddev=0.1))\n",
    "}\n",
    "\n",
    "def mlp(_X, _W, _b, _keepprob):\n",
    "    # Reshape input\n",
    "    _input_r = tf.reshape(_X, shape=[-1, 28, 28, 1])\n",
    "    # Conv1 \n",
    "    _conv1 = tf.nn.relu(\n",
    "            tf.nn.bias_add(\n",
    "                tf.nn.conv2d(_input_r, _W['wc1'], strides=[1, 1, 1, 1], padding='SAME')\n",
    "            , _b['bc1'])\n",
    "        )\n",
    "    # Pool1\n",
    "    _pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n",
    "    # DropOut1\n",
    "    _layer1 = tf.nn.dropout(_pool1, _keepprob)\n",
    "    # Conv2\n",
    "    _conv2 = tf.nn.relu(\n",
    "            tf.nn.bias_add(\n",
    "                tf.nn.conv2d(_layer1, _W['wc2'], strides=[1, 1, 1, 1], padding='SAME')\n",
    "            , _b['bc2'])\n",
    "        )\n",
    "    # Pool2 (Global average pooling)\n",
    "    _pool2 = tf.nn.avg_pool(_conv2, ksize=[1, 14, 14, 1], strides=[1, 14, 14, 1], padding='SAME')\n",
    "    # DropOut2\n",
    "    _layer2 = tf.nn.dropout(_pool2, _keepprob)\n",
    "    # Vectorize\n",
    "    _dense = tf.reshape(_layer2, [-1, _W['out'].get_shape().as_list()[0]])\n",
    "    # FC1\n",
    "    _out   = tf.nn.softmax(tf.add(tf.matmul(_dense, _W['out']), _b['out']))\n",
    "    out = {\n",
    "        'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'layer1': _layer1,\n",
    "        'conv2': _conv2, 'pool2': _pool2, 'layer2': _layer2, 'dense': _dense, \n",
    "        'out': _out\n",
    "    }\n",
    "    return out\n",
    "\n",
    "\n",
    "\n",
    "# Define Parameter\n",
    "learning_rate   = 0.001\n",
    "training_epochs = 10\n",
    "batch_size      = 100\n",
    "display_step    = 1\n",
    "\n",
    "# Define Functions\n",
    "x = tf.placeholder(tf.float32, [None, 784])\n",
    "y = tf.placeholder(tf.float32, [None, 10])\n",
    "keepprob = tf.placeholder(tf.float32)\n",
    "pred = mlp(x, weights, biases, keepprob)['out']\n",
    "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))\n",
    "optm = tf.train.AdamOptimizer(learning_rate).minimize(cost)\n",
    "corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))\n",
    "accr = tf.reduce_mean(tf.cast(corr, tf.float32))\n",
    "init = tf.initialize_all_variables()\n",
    "print (\"Net Ready\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Let's do it!\n",
    "sess = tf.Session()\n",
    "sess.run(init)\n",
    "\n",
    "# Training cycle\n",
    "for epoch in range(training_epochs):\n",
    "    sum_cost = 0.\n",
    "    total_batch = int(mnist.train.num_examples/batch_size)\n",
    "    # Loop over all batches\n",
    "    for i in range(total_batch):\n",
    "        batch_xs, batch_ys = mnist.train.next_batch(batch_size)\n",
    "        # Fit training using batch data\n",
    "        sess.run(optm, feed_dict={x: batch_xs, y: batch_ys, keepprob: 0.7})\n",
    "        # Compute average loss\n",
    "        curr_cost = sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keepprob: 1.})\n",
    "        sum_cost = sum_cost + curr_cost\n",
    "    avg_cost = sum_cost / total_batch\n",
    "    \n",
    "    # Display logs per epoch step\n",
    "    if epoch % display_step == 0:\n",
    "        print (\"Epoch: %03d/%03d cost: %.9f\" % (epoch, training_epochs, avg_cost))\n",
    "        train_acc = sess.run(accr, feed_dict={x: batch_xs, y: batch_ys, keepprob: 1.})\n",
    "        print (\"Training accuracy: %.3f\" % (train_acc))\n",
    "        test_acc = sess.run(accr, feed_dict={x: testimgs, y: testlabels, keepprob: 1.})\n",
    "        print (\"Test accuracy: %.3f\" % (test_acc))\n",
    "\n",
    "print (\"Optimization Finished!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Get Random Image\n",
    "randidx = np.random.randint(ntest)\n",
    "testimg = testimgs[randidx:randidx+1, :]\n",
    "\n",
    "# Run Network \n",
    "inputimg = sess.run(mlp(x, weights, biases, keepprob)['input_r'],\n",
    "                feed_dict={x: testimg, keepprob: 1.})\n",
    "outval = sess.run(mlp(x, weights, biases, keepprob)['out'],\n",
    "                feed_dict={x: testimg, keepprob: 1.})\n",
    "camval = sess.run(mlp(x, weights, biases, keepprob)['conv2'],\n",
    "                feed_dict={x: testimg, keepprob: 1.})\n",
    "cweights = sess.run(weights['out'])\n",
    "\n",
    "# Plot original Image \n",
    "plt.matshow(inputimg[0, :, :, 0], cmap=plt.get_cmap('gray'))\n",
    "plt.title(\"Input image\")\n",
    "plt.colorbar()\n",
    "plt.show()\n",
    "    \n",
    "# Plot class activation maps \n",
    "fig, axs = plt.subplots(2, 5, figsize=(15, 6))    \n",
    "for i in range(10):\n",
    "    predlabel   = np.argmax(outval)\n",
    "    predweights = cweights[:, i:i+1]\n",
    "    camsum = np.zeros((14, 14))\n",
    "    for j in range(256):\n",
    "        camsum = camsum + predweights[j]*camval[0, :, :, j]\n",
    "    camavg = camsum / 256\n",
    "    # Plot \n",
    "    im = axs[int(i/5)][i%5].matshow(camavg, cmap=plt.get_cmap('gray'))\n",
    "    axs[int(i/5)][i%5].set_title((\"[%d] prob is %.3f\") % (i, outval[0, i]))\n",
    "    plt.draw()\n",
    "\n",
    "fig.subplots_adjust(right=0.8)\n",
    "cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n",
    "fig.colorbar(im, cax=cbar_ax)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
